{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Recurrent Neural Networks with ``gluon``\n",
    "\n",
    "\n",
    "With ``gluon``, recurrent neural networks (RNNs) such as the long short-term memory (LSTM) and the gated recurrent unit (GRU) can be trained concisely. To demonstrate the end-to-end RNN training and prediction pipeline, we take a classic language modeling problem as a case study: predicting the distribution of the next word given a sequence of previous words."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Import packages\n",
    "\n",
    "To begin with, we import the packages used throughout this notebook."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "import math\n",
    "import os\n",
    "import time\n",
    "import numpy as np\n",
    "import mxnet as mx\n",
    "from mxnet import gluon, autograd\n",
    "from mxnet.gluon import nn, rnn"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Define classes for indexing words of the input document\n",
    "\n",
    "For language modeling, we define the following classes to handle the routine steps of loading document data. The ``Dictionary`` class handles word indexing: it converts each word in the documents from a string to an integer.\n",
    "\n",
    "In this example, we use consecutive integers to index words of the input document."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "class Dictionary(object):\n",
    "    def __init__(self):\n",
    "        self.word2idx = {}\n",
    "        self.idx2word = []\n",
    "\n",
    "    def add_word(self, word):\n",
    "        if word not in self.word2idx:\n",
    "            self.idx2word.append(word)\n",
    "            self.word2idx[word] = len(self.idx2word) - 1\n",
    "        return self.word2idx[word]\n",
    "\n",
    "    def __len__(self):\n",
    "        return len(self.idx2word)"
   ]
  },
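  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a quick illustration of the indexing scheme, the cell below is a minimal sketch that relies on the ``Dictionary`` class defined above. The toy sentence is made up for this example and is not part of the data set. Each new word receives the next free integer, and a repeated word reuses its existing index."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "toy_dict = Dictionary()\n",
    "toy_words = 'the cat sat on the mat'.split()  # hypothetical toy sentence\n",
    "\n",
    "# add_word returns the index of the word, adding it first if it is new\n",
    "toy_ids = [toy_dict.add_word(w) for w in toy_words]\n",
    "\n",
    "print(toy_ids)            # [0, 1, 2, 3, 0, 4]: the second 'the' reuses index 0\n",
    "print(toy_dict.idx2word)  # reverse mapping from index to word\n",
    "print(len(toy_dict))      # 5 distinct words"
   ]
  },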
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The ``Dictionary`` class is used by the ``Corpus`` class to index the words of the input document."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "class Corpus(object):\n",
    "    def __init__(self, path):\n",
    "        self.dictionary = Dictionary()\n",
    "        self.train = self.tokenize(path + 'train.txt')\n",
    "        self.valid = self.tokenize(path + 'valid.txt')\n",
    "        self.test = self.tokenize(path + 'test.txt')\n",
    "\n",
    "    def tokenize(self, path):\n",
    "        \"\"\"Tokenizes a text file.\"\"\"\n",
    "        assert os.path.exists(path)\n",
    "        # Add words to the dictionary\n",
    "        with open(path, 'r') as f:\n",
    "            tokens = 0\n",
    "            for line in f:\n",
    "                words = line.split() + ['<eos>']\n",
    "                tokens += len(words)\n",
    "                for word in words:\n",
    "                    self.dictionary.add_word(word)\n",
    "\n",
    "        # Tokenize file content\n",
    "        with open(path, 'r') as f:\n",
    "            ids = np.zeros((tokens,), dtype='int32')\n",
    "            token = 0\n",
    "            for line in f:\n",
    "                words = line.split() + ['<eos>']\n",
    "                for word in words:\n",
    "                    ids[token] = self.dictionary.word2idx[word]\n",
    "                    token += 1\n",
    "\n",
    "        return mx.nd.array(ids, dtype='int32')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Define different RNN models with ``gluon``\n",
    "\n",
    "Based on the ``gluon.Block`` class, we can make different RNN models available with the following single ``RNNModel`` class.\n",
    "\n",
    "Users can select their preferred RNN model, or compare different RNN models, by setting the ``mode`` argument of the ``RNNModel`` constructor. We will show an example following the definition of the ``RNNModel`` class."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "class RNNModel(gluon.Block):\n",
    "    \"\"\"A model with an encoder, recurrent layer, and a decoder.\"\"\"\n",
    "\n",
    "    def __init__(self, mode, vocab_size, num_embed, num_hidden,\n",
    "                 num_layers, dropout=0.5, tie_weights=False, **kwargs):\n",
    "        super(RNNModel, self).__init__(**kwargs)\n",
    "        with self.name_scope():\n",
    "            self.drop = nn.Dropout(dropout)\n",
    "            self.encoder = nn.Embedding(vocab_size, num_embed,\n",
    "                                        weight_initializer = mx.init.Uniform(0.1))\n",
    "            if mode == 'rnn_relu':\n",
    "                self.rnn = rnn.RNN(num_hidden, num_layers, activation='relu', dropout=dropout,\n",
    "                                   input_size=num_embed)\n",
    "            elif mode == 'rnn_tanh':\n",
    "                self.rnn = rnn.RNN(num_hidden, num_layers, dropout=dropout,\n",
    "                                   input_size=num_embed)\n",
    "            elif mode == 'lstm':\n",
    "                self.rnn = rnn.LSTM(num_hidden, num_layers, dropout=dropout,\n",
    "                                    input_size=num_embed)\n",
    "            elif mode == 'gru':\n",
    "                self.rnn = rnn.GRU(num_hidden, num_layers, dropout=dropout,\n",
    "                                   input_size=num_embed)\n",
    "            else:\n",
    "                raise ValueError(\"Invalid mode %s. Options are rnn_relu, \"\n",
    "                                 \"rnn_tanh, lstm, and gru\"%mode)\n",
    "            if tie_weights:\n",
    "                self.decoder = nn.Dense(vocab_size, in_units = num_hidden,\n",
    "                                        params = self.encoder.params)\n",
    "            else:\n",
    "                self.decoder = nn.Dense(vocab_size, in_units = num_hidden)\n",
    "            self.num_hidden = num_hidden\n",
    "\n",
    "    def forward(self, inputs, hidden):\n",
    "        emb = self.drop(self.encoder(inputs))\n",
    "        output, hidden = self.rnn(emb, hidden)\n",
    "        output = self.drop(output)\n",
    "        decoded = self.decoder(output.reshape((-1, self.num_hidden)))\n",
    "        return decoded, hidden\n",
    "\n",
    "    def begin_state(self, *args, **kwargs):\n",
    "        return self.rnn.begin_state(*args, **kwargs)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Select an RNN model and configure parameters\n",
    "\n",
    "For demonstration purposes, we choose the parameter values arbitrarily. In practice, some of them should be tuned on the validation data set.\n",
    "\n",
    "For instance, to obtain better performance, as reflected in a lower loss or perplexity, one can set ``args_epochs`` to a larger value.\n",
    "\n",
    "In this demonstration, ``args_model`` is set to ``'rnn_relu'``, a vanilla RNN with ReLU activation. To try the other RNN options, replace this string with ``'rnn_tanh'``, ``'lstm'``, or ``'gru'``, as supported by the ``RNNModel`` class defined above."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "args_data = '../data/nlp/ptb.'\n",
    "args_model = 'rnn_relu'\n",
    "args_emsize = 100\n",
    "args_nhid = 100\n",
    "args_nlayers = 2\n",
    "args_lr = 1.0\n",
    "args_clip = 0.2\n",
    "args_epochs = 1\n",
    "args_batch_size = 32\n",
    "args_bptt = 5\n",
    "args_dropout = 0.2\n",
    "args_tied = True\n",
    "args_cuda = 'store_true'\n",
    "args_log_interval = 500\n",
    "args_save = 'model.param'"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Load data as batches\n",
    "\n",
    "We load the document data by leveraging the aforementioned ``Corpus`` class. \n",
    "\n",
    "To speed up the subsequent data flow in the RNN model, we pre-process the loaded data as batches. This procedure is defined in the following ``batchify`` function."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "context = mx.gpu() # this notebook takes too long on cpu\n",
    "corpus = Corpus(args_data)\n",
    "\n",
    "def batchify(data, batch_size):\n",
    "    \"\"\"Reshape data into (num_example, batch_size)\"\"\"\n",
    "    nbatch = data.shape[0] // batch_size\n",
    "    data = data[:nbatch * batch_size]\n",
    "    data = data.reshape((batch_size, nbatch)).T\n",
    "    return data\n",
    "\n",
    "train_data = batchify(corpus.train, args_batch_size).as_in_context(context)\n",
    "val_data = batchify(corpus.valid, args_batch_size).as_in_context(context)\n",
    "test_data = batchify(corpus.test, args_batch_size).as_in_context(context)"
   ]
  },
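  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To see what the reshape in ``batchify`` does, the following is a small sketch that applies the same reshape-and-transpose to a toy NumPy array. The names ``toy`` and ``toy_batches`` are illustrative only. Each column of the result is one contiguous stream of the original sequence, and each row is one time step across all streams."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "toy = np.arange(10)   # stand-in for a corpus of 10 word indices\n",
    "toy_batch_size = 2\n",
    "toy_nbatch = toy.shape[0] // toy_batch_size\n",
    "\n",
    "# Same steps as batchify: trim, reshape to (batch_size, nbatch), transpose\n",
    "toy_batches = toy[:toy_nbatch * toy_batch_size].reshape((toy_batch_size, toy_nbatch)).T\n",
    "\n",
    "print(toy_batches.shape)   # (5, 2)\n",
    "print(toy_batches[:, 0])   # [0 1 2 3 4], the first stream\n",
    "print(toy_batches[:, 1])   # [5 6 7 8 9], the second stream"
   ]
  },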
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Build the model\n",
    "\n",
    "We go on to build the model, initialize model parameters, and configure the optimization algorithms for training the RNN model."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "ntokens = len(corpus.dictionary)\n",
    "\n",
    "model = RNNModel(args_model, ntokens, args_emsize, args_nhid,\n",
    "                       args_nlayers, args_dropout, args_tied)\n",
    "model.collect_params().initialize(mx.init.Xavier(), ctx=context)\n",
    "trainer = gluon.Trainer(model.collect_params(), 'sgd',\n",
    "                        {'learning_rate': args_lr, 'momentum': 0, 'wd': 0})\n",
    "loss = gluon.loss.SoftmaxCrossEntropyLoss()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Train the model and evaluate on validation and testing data sets\n",
    "\n",
    "Now we can define functions for training and evaluating the model. The following are two helper functions that will be used during model training and evaluation."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "def get_batch(source, i):\n",
    "    seq_len = min(args_bptt, source.shape[0] - 1 - i)\n",
    "    data = source[i : i + seq_len]\n",
    "    target = source[i + 1 : i + 1 + seq_len]\n",
    "    return data, target.reshape((-1,))\n",
    "\n",
    "def detach(hidden):\n",
    "    if isinstance(hidden, (tuple, list)):\n",
    "        hidden = [i.detach() for i in hidden]\n",
    "    else:\n",
    "        hidden = hidden.detach()\n",
    "    return hidden"
   ]
  },
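  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The slicing in ``get_batch`` pairs every input word with the word that follows it in the same stream. The sketch below repeats that slicing on a toy NumPy array with a hypothetical truncation length of 3, so the one-step shift between ``data`` and ``target`` is easy to see. All ``toy_`` names are illustrative."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "# 6 time steps, 2 streams: column 0 is 0..5 and column 1 is 6..11\n",
    "toy_source = np.arange(12).reshape((2, 6)).T\n",
    "toy_bptt = 3   # hypothetical truncation length\n",
    "i = 0\n",
    "\n",
    "seq_len = min(toy_bptt, toy_source.shape[0] - 1 - i)\n",
    "toy_data = toy_source[i : i + seq_len]\n",
    "toy_target = toy_source[i + 1 : i + 1 + seq_len]\n",
    "\n",
    "print(toy_data)                   # rows [0 6], [1 7], [2 8]\n",
    "print(toy_target)                 # rows [1 7], [2 8], [3 9]: shifted by one step\n",
    "print(toy_target.reshape((-1,)))  # [1 7 2 8 3 9], flattened as in get_batch"
   ]
  },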
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The following is the function for model evaluation. It returns the loss of the model prediction. We will discuss the details of the loss measure shortly."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "def eval(data_source):\n",
    "    total_L = 0.0\n",
    "    ntotal = 0\n",
    "    hidden = model.begin_state(func = mx.nd.zeros, batch_size = args_batch_size, ctx=context)\n",
    "    for i in range(0, data_source.shape[0] - 1, args_bptt):\n",
    "        data, target = get_batch(data_source, i)\n",
    "        output, hidden = model(data, hidden)\n",
    "        L = loss(output, target)\n",
    "        total_L += mx.nd.sum(L).asscalar()\n",
    "        ntotal += L.size\n",
    "    return total_L / ntotal"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now we are ready to define the function for training the model. We can monitor the model performance on the training, validation, and testing data sets over iterations."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "def train():\n",
    "    best_val = float(\"Inf\")\n",
    "    for epoch in range(args_epochs):\n",
    "        total_L = 0.0\n",
    "        start_time = time.time()\n",
    "        hidden = model.begin_state(func = mx.nd.zeros, batch_size = args_batch_size, ctx = context)\n",
    "        for ibatch, i in enumerate(range(0, train_data.shape[0] - 1, args_bptt)):\n",
    "            data, target = get_batch(train_data, i)\n",
    "            hidden = detach(hidden)\n",
    "            with autograd.record():\n",
    "                output, hidden = model(data, hidden)\n",
    "                L = loss(output, target)\n",
    "                L.backward()\n",
    "\n",
    "            grads = [i.grad(context) for i in model.collect_params().values()]\n",
    "            # Here gradient is for the whole batch.\n",
    "            # So we multiply max_norm by batch_size and bptt size to balance it.\n",
    "            gluon.utils.clip_global_norm(grads, args_clip * args_bptt * args_batch_size)\n",
    "\n",
    "            trainer.step(args_batch_size)\n",
    "            total_L += mx.nd.sum(L).asscalar()\n",
    "\n",
    "            if ibatch % args_log_interval == 0 and ibatch > 0:\n",
    "                cur_L = total_L / args_bptt / args_batch_size / args_log_interval\n",
    "                print('[Epoch %d Batch %d] loss %.2f, perplexity %.2f' % (\n",
    "                    epoch + 1, ibatch, cur_L, math.exp(cur_L)))\n",
    "                total_L = 0.0\n",
    "\n",
    "        val_L = eval(val_data)\n",
    "\n",
    "        print('[Epoch %d] time cost %.2fs, validation loss %.2f, validation perplexity %.2f' % (\n",
    "            epoch + 1, time.time() - start_time, val_L, math.exp(val_L)))\n",
    "\n",
    "        if val_L < best_val:\n",
    "            best_val = val_L\n",
    "            test_L = eval(test_data)\n",
    "            model.save_parameters(args_save)\n",
    "            print('test loss %.2f, test perplexity %.2f' % (test_L, math.exp(test_L)))\n",
    "        else:\n",
    "            # Validation loss did not improve: cut the learning rate to a quarter\n",
    "            # and roll back to the best parameters saved so far.\n",
    "            trainer.set_learning_rate(trainer.learning_rate * 0.25)\n",
    "            model.load_parameters(args_save, context)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Recall that RNN model training is based on maximizing the likelihood of the observations. For evaluation purposes, we have used the following two measures:\n",
    "\n",
    "* Loss: the loss function is defined as the average negative log likelihood of the target words (ground truth) under prediction: $$\\text{loss} = -\\frac{1}{N} \\sum_{i = 1}^N \\log p_{\\text{target}_i}, $$ where $N$ is the number of predictions and $p_{\\text{target}_i}$ is the predicted likelihood of the $i$-th target word.\n",
    "\n",
    "* Perplexity: the average per-word perplexity is $\\text{exp}(\\text{loss})$.\n",
    "\n",
    "To orient the reader using concrete examples, let us illustrate the idea of the perplexity measure as follows.\n",
    "\n",
    "* Consider the perfect scenario where the model always predicts the likelihood of the target word as 1. In this case, for every $i$ we have $p_{\\text{target}_i} = 1$. As a result, the perplexity of the perfect model is 1. \n",
    "\n",
    "* Consider a baseline scenario where the model predicts the target word uniformly at random from the given word set $W$. In this case, for every $i$ we have $p_{\\text{target}_i} = 1 / |W|$. As a result, the perplexity of a uniformly random prediction model is always $|W|$. \n",
    "\n",
    "* Consider the worst-case scenario where the model always predicts the likelihood of the target word as 0. In this case, for every $i$ we have $p_{\\text{target}_i} = 0$. As a result, the perplexity of the worst model is positive infinity. \n",
    "\n",
    "\n",
    "Therefore, a lower perplexity, closer to 1, generally indicates a more effective model. Any useful model has to achieve a perplexity lower than the vocabulary size $|W|$, which is the perplexity of uniform guessing.\n",
    "\n",
    "Now we are ready to train the model and evaluate the model performance on validation and testing data sets. "
   ]
  },
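  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Before running the training loop, the perplexity scenarios above can be checked numerically. The cell below is a minimal sketch using only the standard library. The helper ``toy_perplexity`` and the vocabulary size of 10000 are assumptions made for illustration and are not part of the training code."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "import math\n",
    "\n",
    "def toy_perplexity(target_probs):\n",
    "    # exp of the average negative log likelihood of the target words\n",
    "    avg_nll = -sum(math.log(p) for p in target_probs) / len(target_probs)\n",
    "    return math.exp(avg_nll)\n",
    "\n",
    "toy_vocab_size = 10000  # hypothetical |W|\n",
    "\n",
    "# Perfect model: every target word gets probability 1, so perplexity is 1\n",
    "print(toy_perplexity([1.0] * 5))\n",
    "\n",
    "# Uniform model: every target word gets probability 1 / |W|,\n",
    "# so perplexity is approximately |W| (up to floating-point rounding)\n",
    "print(toy_perplexity([1.0 / toy_vocab_size] * 5))"
   ]
  },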
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "train()\n",
    "model.load_parameters(args_save, context)\n",
    "test_L = eval(test_data)\n",
    "print('Best test loss %.2f, test perplexity %.2f'%(test_L, math.exp(test_L)))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Next\n",
    "[Introduction to optimization](../chapter06_optimization/optimization-intro.ipynb)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "For whinges or inquiries, [open an issue on GitHub.](https://github.com/zackchase/mxnet-the-straight-dope)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
